iT邦幫忙

2022 iThome 鐵人賽

0
AI & Data

JAX 好好玩系列 第 36

JAX 好好玩 (36) : Flax (2) : 第一個範例程式

  • 分享至 

  • xImage
  •  

這一個範例程式的目的,是給大家一個整體的概念,看看要設計及訓練一個 Flax 神經網路模型要做那些事情。老頭改寫了 Flax 官網文件上的 Getting Started 程式,希望能夠給讀者們一個更清晰的程式架構。程式的 colab 檔可以由此下載,大家先去跑一次,再讀讀程式中的註解,應該就可以對於 Flax 有些初步的認識。

之前老頭曾以 Pytorch 來載入 MNIST 資料集,這一回 Flax 選用 TensorFlow 的 dataset 服務來載入,做為訓練的標的。目前 JAX 的生態系裏還沒有看到自行定義的 dataset 服務,應該也沒有什麼必要自己再弄一套,利用現有的就可以了。

這個範例老頭分別以 Flax 中的「完整寫法」及「精簡寫法」[36.1]設計兩個結構完全一樣的 CNN 神經網路,之後分別訓練這兩個網路作為比較。這個設計有點模仿 Keras 中的 Sequential 機制,對於簡單的模型,「精簡寫法」更為簡潔。

另外要注意的是「Flax 模型的參數是獨立於模型之外的」!原因在於必須維持模型運算的「純粹 pure」,相同的輸入,得到相同的輸出 (讀者可以參考老頭先前有關純函式的貼文)。模型本身只保留運算流程,而把訓練 (或推理)資料及參數當做輸入,那麼整個模型的計算,就能維持純函式的特性,利於 JIT 編譯以加快運算速度。

有了整體的感覺之後,接著就可以進入細節了。

[36.1] Flax 正式的名稱為「明確的 explicitly」宣告法及「行內的 in-line」宣告法。


上一篇
JAX 好好玩 (35) : Flax (1) : 準備學習 Flax
下一篇
JAX 好好玩 (37) : Flax (3) : 第二個範例程式
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言